package shell

import (
	"fmt"
	"github.com/fatih/color"
	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh/terminal"
	"gosh/pkg/connect"
	"gosh/pkg/handler"
	"io"
	"io/fs"
	"os"
	"path"
	"path/filepath"
	"strings"
	"time"
)

// PushFiles prepares the bastion (jump server) settings via initJumpServer
// and then hands SSHConfig's host list and output options to handler.Handler,
// using push as the per-host callback. If the bastion setup fails, the error
// is printed in red and no host is contacted.
func PushFiles() {
	err := initJumpServer()
	if err == nil {
		handler.Handler(SSHConfig.Fork, SSHConfig.Hosts, SSHConfig.PrintStatus, SSHConfig.PrintFields, push)
		return
	}
	color.Red("init bastion failed,err:%s\n", err.Error())
}

// push uploads the paths listed in SSHConfig.Args to the host at ip over SFTP.
// Connection parameters come from hostMap[ip]; the jump server is looked up in
// jumpserverMap by that host's BastionIP. The SFTP session is only closed when
// OpenSftp succeeded. It returns whatever sendFiles produces.
func push(ip string) (stdout, stderr []byte, err error) {
	host := hostMap[ip]
	sshHost := connect.NewHost(
		ip,
		host.SSHUser,
		host.SSHPassword,
		host.SSHKeyContent,
		host.SSHKeyPassphrase,
		host.SSHPort,
		jumpserverMap[host.BastionIP],
		host.SSHTimeout,
	)
	if err = sshHost.OpenSftp(); err != nil {
		return nil, nil, err
	}
	defer sshHost.Close()

	return sendFiles(sshHost.SftpClient, SSHConfig.Args)
}

// sendFiles uploads local files and directories to a remote host over SFTP.
//
// args holds one or more local paths followed by the remote destination
// directory as its last element. Each local path is walked recursively and
// every non-directory entry is uploaded beneath the remote directory, keeping
// the path relative to the local path's parent (so a directory "a" arrives as
// remoteDir/a/..., while "a/" with a trailing slash sends only its contents).
//
// One summary line (built by resFormat) is appended to the returned stdout per
// local path. The returned stderr is currently always empty. On the first
// failure the output gathered so far is returned together with the error.
func sendFiles(client *sftp.Client, args []string) ([]byte, []byte, error) {
	sendStdout := make([]byte, 0, 1024)
	sendStderr := make([]byte, 0, 1024)

	if len(args) == 0 {
		return sendStdout, sendStderr, fmt.Errorf("push: no remote directory given")
	}
	// Clean rather than trim a trailing "/": trimming turns "/" into "", which
	// would resolve against the SFTP working directory. Cleaning also makes the
	// useLocalDirMode comparison below match path.Join's cleaned output.
	remoteDir := path.Clean(args[len(args)-1])

	for _, localRoot := range args[:len(args)-1] {
		start := time.Now()
		if _, err := os.Stat(localRoot); err != nil {
			return sendStdout, sendStderr, err
		}

		// Copy one top-level path at a time so each file's destination can be
		// derived relative to that path's parent directory.
		localFiles, err := collectPushFiles(localRoot)
		if err != nil {
			return sendStdout, sendStderr, err
		}

		var size int64
		for _, file := range localFiles {
			stat, err := os.Stat(file)
			if err != nil {
				return sendStdout, sendStderr, err
			}
			remoteFile, err := pushRemotePath(localRoot, file, remoteDir)
			if err != nil {
				return sendStdout, sendStderr, err
			}
			// Files that do not sit directly in remoteDir need their remote
			// parent directories created (and the parent's mode copied).
			useLocalDirMode := path.Dir(remoteFile) != remoteDir
			if err := sendfile(file, remoteFile, stat.Mode(), useLocalDirMode, client); err != nil {
				return sendStdout, sendStderr, fmt.Errorf("send %s to %s: %w", file, remoteFile, err)
			}
			size += stat.Size()
		}
		sendStdout = append(sendStdout, resFormat(size, localRoot, time.Since(start).Nanoseconds())...)
	}

	return sendStdout, sendStderr, nil
}

// collectPushFiles returns every non-directory path under root (root itself
// when it is not a directory), in filepath.Walk's lexical order.
func collectPushFiles(root string) ([]string, error) {
	var files []string
	err := filepath.Walk(root, func(p string, info fs.FileInfo, err error) error {
		if err != nil {
			// info may be nil when err is set; report instead of dereferencing.
			return err
		}
		if !info.IsDir() {
			files = append(files, p)
		}
		return nil
	})
	return files, err
}

// pushRemotePath maps a local file found under localRoot to its remote
// destination: file's path relative to localRoot's parent directory,
// converted to forward slashes and joined onto remoteDir.
func pushRemotePath(localRoot, file, remoteDir string) (string, error) {
	rel, err := filepath.Rel(filepath.Dir(localRoot), file)
	if err != nil {
		return "", err
	}
	return path.Join(remoteDir, filepath.ToSlash(rel)), nil
}

// sendfile uploads one local file to remoteFile through client.
//
// When useLocalDirMode is true the remote parent directory is created with
// MkdirAll and then (best effort) given the mode of the local file's parent
// directory. The remote file is created, (best effort) chmod'ed to fileMode,
// and filled with the local file's contents. Errors from closing the remote
// file are returned, because a failed SFTP close can mean the upload did not
// complete.
func sendfile(localFile, remoteFile string, fileMode fs.FileMode, useLocalDirMode bool, client *sftp.Client) (err error) {
	localFd, err := os.Open(localFile)
	if err != nil {
		return err
	}
	defer func() { _ = localFd.Close() }()

	if useLocalDirMode {
		remoteParent := path.Dir(remoteFile)
		if err = client.MkdirAll(remoteParent); err != nil {
			return err
		}
		// localFile is an OS path, so use filepath (not the slash-only path package).
		dirStat, statErr := os.Stat(filepath.Dir(localFile))
		if statErr != nil {
			return statErr
		}
		// Best effort: a permission mismatch should not abort the upload.
		_ = client.Chmod(remoteParent, dirStat.Mode())
	}

	remoteFd, err := client.Create(remoteFile)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := remoteFd.Close(); err == nil {
			err = closeErr
		}
	}()
	// Best effort, same reasoning as the directory chmod above.
	_ = remoteFd.Chmod(fileMode)
	_, err = io.Copy(remoteFd, localFd)
	return err
}

// formatFileSize renders a byte count with two decimals and a binary-unit
// suffix: "B" below 1 KiB, then "K", "M", and "G" for everything from 1 GiB up.
func formatFileSize(s int64) string {
	const (
		kib = 1024
		mib = kib * 1024
		gib = mib * 1024
	)
	switch {
	case s < kib:
		return fmt.Sprintf("%.2fB", float64(s))
	case s < mib:
		return fmt.Sprintf("%.2fK", float64(s)/kib)
	case s < gib:
		return fmt.Sprintf("%.2fM", float64(s)/mib)
	default:
		return fmt.Sprintf("%.2fG", float64(s)/gib)
	}
}

// resFormat renders one upload summary line. The file name is left-aligned
// and padded to (terminal width - 30) columns. It is followed by the total
// size and the average throughput, each right-aligned in 15 columns.
//
// n is the number of bytes sent and cost the elapsed time in nanoseconds.
// The width is read from stdin's file descriptor; if that fails (for example
// when stdin is not a terminal), 120 columns are assumed.
func resFormat(n int64, file string, cost int64) []byte {
	size := formatFileSize(n)
	// Computed in floating point so n*1e9 cannot overflow int64 on multi-GB
	// transfers, and guarded so a zero or negative duration cannot trigger an
	// integer divide-by-zero panic.
	var bytesPerSec int64
	if cost > 0 {
		bytesPerSec = int64(float64(n) * float64(time.Second) / float64(cost))
	}
	speed := formatFileSize(bytesPerSec) + "/s"

	width, _, err := terminal.GetSize(int(os.Stdin.Fd()))
	if err != nil {
		width = 120
	}
	return []byte(fillString(width-30, file, true) + fillString(15, size, false) + fillString(15, speed, false))
}

// fillString pads content with spaces up to width runes. When left is true
// the text is left-aligned (padding appended), otherwise right-aligned
// (padding prepended). Content already at least width runes long is
// returned unchanged, never truncated.
func fillString(width int, content string, left bool) string {
	gap := width - len([]rune(content))
	if gap <= 0 {
		return content
	}
	padding := strings.Repeat(" ", gap)
	if !left {
		return padding + content
	}
	return content + padding
}
